#!/usr/bin/env python3

import numpy as np

# costs
def c_alg(x, lam, b):
    """Alg cost."""
    return np.where(x < lam, x, lam + b)

def c_opt(x, b):
    """Offline optimum cost: elementwise ``min(x, b)``."""
    best = np.minimum(x, b)
    return best

# weights on [y-h1, y+h2]
def w_linear(x, y, h1, h2):
    """Triangular weight on ``[y - h1, y + h2]`` with peak 1 at ``x == y``.

    The weight falls linearly to 0 at ``y - h1`` on the left and at
    ``y + h2`` on the right; points outside that interval get weight 0.

    Args:
        x: Evaluation points (array-like). It is converted to a float array
            so that integer input is not truncated.
        y: Peak location.
        h1: Left half-width. With a zero width, the left side is only the
            peak itself.
        h2: Right half-width.

    Returns:
        Float array of weights with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    w = np.zeros_like(x)
    left = (x >= y - h1) & (x <= y)
    right = (x > y) & (x <= y + h2)
    if h1 > 0:
        w[left] = 1 - (y - x[left]) / h1
    else:
        # Zero width: ``left`` can only contain x == y, whose weight is 1.
        # Dividing would give 0/0 = NaN.
        w[left] = 1.0
    w[right] = 1 - (x[right] - y) / h2
    return w

def w_gauss(x, y, h1, h2):
    """Gaussian weight centred at ``y``, zeroed outside ``[y - h1, y + h2]``.

    The standard deviation is ``(h1 + h2) / 8``, so the full window is
    eight standard deviations wide.
    """
    sigma = (h1 + h2) / 8.0
    z = (x - y) / sigma
    bell = np.exp(-0.5 * z ** 2)
    inside = (x >= y - h1) & (x <= y + h2)
    return np.where(inside, bell, 0.0)

# densities on [y-h1, y+h2] (normalized)
def mu_linear(x, y, h1, h2):
    """Triangular density on ``[y - h1, y + h2]`` with its peak at ``y``.

    The unnormalised shape rises linearly from 0 at ``y - h1`` to 1 at ``y``
    and falls back to 0 at ``y + h2``. It is then divided by its trapezoidal
    integral over ``x``. If that integral is not positive (the window misses
    the grid), the unnormalised array is returned.

    Args:
        x: Grid points (array-like). It is converted to float so that integer
            grids are not truncated.
        y: Peak location.
        h1: Left half-width. With a zero width, the left side is only the
            peak itself.
        h2: Right half-width.

    Returns:
        Float array of density values on ``x``.
    """
    x = np.asarray(x, dtype=float)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    mu = np.zeros_like(x)
    left = (x >= y - h1) & (x <= y)
    right = (x > y) & (x <= y + h2)
    if h1 > 0:
        mu[left] = (x[left] - (y - h1)) / h1
    else:
        # Zero width: only x == y can be in ``left``. That is the peak
        # (value 1), and dividing would give 0/0.
        mu[left] = 1.0
    mu[right] = ((y + h2) - x[right]) / h2
    Z = integrate(mu, x)
    return mu / Z if Z > 0 else mu

def mu_gauss(x, y, h1, h2):
    """Truncated Gaussian density on ``[y - h1, y + h2]``, normalised on ``x``.

    The density is centred at ``y`` with standard deviation
    ``(h1 + h2) / 8`` and is zero outside the window. It is divided by its
    trapezoidal integral over ``x``. If that integral is not positive, the
    unnormalised array is returned.

    Args:
        x: Grid points (array-like, converted to float).
        y: Centre.
        h1: Left half-width of the support.
        h2: Right half-width of the support.

    Returns:
        Float array of density values on ``x``.
    """
    x = np.asarray(x, dtype=float)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    s = (h1 + h2) / 8.0
    mu = np.exp(-0.5 * ((x - y) / s) ** 2)
    mu = np.where((x >= y - h1) & (x <= y + h2), mu, 0.0)
    Z = integrate(mu, x)
    return mu / Z if Z > 0 else mu

def mu_uniform(x, y, h1, h2):
    """Flat density on ``[y - h1, y + h2]``, normalised on the grid ``x``.

    The indicator of the window is divided by its trapezoidal integral over
    ``x``. If that integral is not positive (the window misses the grid),
    the all-zero array is returned.

    Args:
        x: Grid points (array-like, converted to float).
        y: Window reference point.
        h1: Extent of the window to the left of ``y``.
        h2: Extent of the window to the right of ``y``.

    Returns:
        Float array of density values on ``x``.
    """
    x = np.asarray(x, dtype=float)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    mu = ((x >= y - h1) & (x <= y + h2)).astype(float)
    Z = integrate(mu, x)
    return mu / Z if Z > 0 else mu

# regret vs IR (r)
def _perf_ir(x, b, r):
    thr = min((b * r) / (r - 1.0), b * (r - 1.0))
    return np.where(
        x < b, 1.0,
        np.where(x < thr, x / b, r / (r - 1.0))
    )

def dist(x, lam, b, w_fn, y, h1, h2, r):
    """Weighted excess of the threshold-``lam`` ratio over the IR ratio.

    Pointwise ``(c_alg / c_opt - _perf_ir) * w_fn(x, y, h1, h2)``.
    """
    alg_ratio = c_alg(x, lam, b) / c_opt(x, b)
    excess = alg_ratio - _perf_ir(x, b, r)
    return excess * w_fn(x, y, h1, h2)

def d_max(lam, x, b, w_fn, y, h1, h2, r):
    """Largest weighted excess ``dist`` over the grid ``x`` for threshold ``lam``."""
    gaps = dist(x, lam, b, w_fn, y, h1, h2, r)
    return np.max(gaps)

def d_avg(lam, x, b, w_fn, y, h1, h2, r):
    """Arithmetic mean of the weighted excess ``dist`` over the grid ``x``."""
    gaps = dist(x, lam, b, w_fn, y, h1, h2, r)
    return np.mean(gaps)

# CVaR objective (piecewise min form)
def cvar_obj(lam, x, b, mu, alpha, return_all=False):
    """CVaR at level ``alpha`` of the threshold algorithm's cost under ``mu``.

    With ``T = lam``, the cost is ``x`` for ``x < T`` and ``T + b`` otherwise.
    The result is the minimum of three candidates:

    * case1: ``(∫_{t_star}^{T} x mu dx + (T + b) P(X >= T)) / (1 - alpha)``,
      the tail expectation above the alpha-quantile ``t_star`` of ``mu``.
    * case2: ``T + b P(X >= T) / (1 - alpha)``, capped at ``T + b``. This is
      the Rockafellar-Uryasev objective evaluated at ``tau = T``.
    * case3: ``T + b``.

    Args:
        lam: Buying threshold ``T``.
        x: Increasing grid of season lengths.
        b: Purchase price.
        mu: Density values on ``x``. Both the CDF and the tail integral are
            normalised by the same discrete mass, so ``mu`` need not
            integrate exactly to 1.
        alpha: CVaR level in ``[0, 1)``.
        return_all: If true, return ``(case1, case2, case3)`` instead of
            their minimum.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1)`` or ``mu`` has no
            positive mass on ``x``.
    """
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if not 0 <= alpha < 1:
        raise ValueError(f"alpha must lie in [0, 1), got {alpha!r}")
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    T = lam

    # Discrete CDF: cumulative mass using per-point widths from np.gradient.
    mass = np.cumsum(mu * np.gradient(x))
    total = mass[-1]
    if not total > 0:  # also rejects NaN
        raise ValueError("mu has no positive mass on the grid x")
    M = mass / total
    density = mu / total  # same normalisation as M, keeps case1 consistent

    t_star = np.interp(alpha, M, x)  # alpha-quantile of mu
    q = 1.0 - np.interp(T, x, M)  # P(X >= T)

    tail = (x >= t_star) & (x < T)
    x1 = x[tail]
    I1 = integrate(x1 * density[tail], x1) if x1.size else 0.0
    case1 = (I1 + (T + b) * q) / (1.0 - alpha)

    case2 = min(T + b * q / (1.0 - alpha), T + b)
    case3 = T + b

    if return_all:
        return case1, case2, case3
    return min(case1, case2, case3)

# discrete search for λ
def find_lambda(x, b, w_fn, y, h1, h2, alpha, objective, mu, r, dump=False):
    """Grid-search the buying threshold ``lam`` that minimises ``objective``.

    The candidates are the points of ``x`` inside
    ``[b / (r - 1), b * (r - 1)]``.

    Args:
        x: Grid of season lengths; also the candidate set for ``lam``.
        b: Purchase price.
        w_fn: Weight function used by the ``"max"``/``"avg"`` objectives.
        y: Centre passed to ``w_fn``.
        h1: Left half-width passed to ``w_fn``.
        h2: Right half-width passed to ``w_fn``.
        alpha: CVaR level, used only by ``"cvar"``.
        objective: One of ``"max"`` (``d_max``), ``"avg"`` (``d_avg``) or
            ``"cvar"`` (``cvar_obj``).
        mu: Density on ``x``, used only by ``"cvar"``.
        r: Parameter setting the candidate range.
        dump: If true, plot the scanned objective curve to
            ``lambda_scan.png``.

    Returns:
        The first candidate with the smallest objective value, as a float.
        NaN (and ``+inf``) values are never selected.

    Raises:
        ValueError: If ``objective`` is unknown, no grid point lies in the
            candidate range, or every candidate's objective is NaN/``+inf``.
    """
    # Validate up front, so an empty grid cannot hide a bad objective.
    if objective not in ("max", "avg", "cvar"):
        raise ValueError(f"objective must be 'max', 'avg' or 'cvar', got {objective!r}")
    lo, hi = b / (r - 1.0), b * (r - 1.0)
    lam_grid = x[(x >= lo) & (x <= hi)]
    if lam_grid.size == 0:
        raise ValueError(f"no point of x lies in the lambda range [{lo}, {hi}]; check r and x")

    def score(lam):
        # Objective value for one candidate threshold.
        if objective == "max":
            return d_max(lam, x, b, w_fn, y, h1, h2, r)
        if objective == "avg":
            return d_avg(lam, x, b, w_fn, y, h1, h2, r)
        return cvar_obj(lam, x, b, mu, alpha)

    vals = []
    best_v, best_l = float("inf"), None
    for lam in lam_grid:
        v = score(lam)
        vals.append(v)
        # Strict "<" keeps the first minimiser; NaN compares False and never wins.
        if v < best_v:
            best_v, best_l = v, lam
    if dump:
        import matplotlib.pyplot as plt

        plt.plot(lam_grid, vals)
        plt.tight_layout()
        plt.savefig("lambda_scan.png")
        plt.close()
    if best_l is None:
        raise ValueError("no candidate lambda has an objective value below +inf (all NaN/inf)")
    return float(best_l)

# main experiment (Monte Carlo)
def _write_rows_csv(path, cols, rows):
    """Write ``rows`` (dicts keyed by ``cols``) to ``path`` as a CSV file."""
    import csv

    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=cols)
        writer.writeheader()
        writer.writerows(rows)


def _plot_avg_ratios(S, CI, labels, plot_path):
    """Save a bar chart of ``S["avg_ratio_<label>"]`` with asymmetric CI bars.

    ``CI`` entries are ``(lower, upper)`` distances from the mean. The figure
    is written as both PDF and PNG. If ``plot_path`` ends in ``.pdf`` or
    ``.png``, that path is used for its own format and the extension is
    swapped for the other format. Any other path gets both extensions
    appended.
    """
    import os

    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    means = [S[f"avg_ratio_{k}"] for k in labels]
    # Clamp at 0, since matplotlib rejects negative error lengths.
    err_lo = [max(0.0, CI[f"avg_ratio_{k}"][0]) for k in labels]
    err_hi = [max(0.0, CI[f"avg_ratio_{k}"][1]) for k in labels]

    stem, ext = os.path.splitext(plot_path)
    ext = ext.lower()
    if ext not in (".pdf", ".png"):
        stem = plot_path
    pdf_path = plot_path if ext == ".pdf" else stem + ".pdf"
    png_path = plot_path if ext == ".png" else stem + ".png"

    pos = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(9.0, 3.6))
    try:
        # yerr of shape (2, N): first row lower errors, second row upper errors.
        ax.bar(pos, means, yerr=[err_lo, err_hi], capsize=3, width=0.35)
        ylo = min(m - e for m, e in zip(means, err_lo))
        yhi = max(m + e for m, e in zip(means, err_hi))
        pad = 0.04 * (yhi - ylo) if yhi > ylo else 0.05 * (yhi if yhi != 0 else 1.0)
        ax.set_ylim(ylo - pad, yhi + pad)
        ax.set_xticks(pos)
        ax.set_xticklabels(labels, rotation=25, ha="right")
        ax.set_ylabel("Avg performance ratio")
        ax.set_title("Ski Rental (mean ± 95% CI)")
        ax.yaxis.set_major_formatter(FuncFormatter(lambda v, _: f"{v:.2f}".rstrip("0").rstrip(".")))
        ax.grid(axis="y", alpha=0.3)
        ax.margins(x=0.02)
        fig.tight_layout()
        fig.savefig(pdf_path, bbox_inches="tight")
        fig.savefig(png_path, dpi=220, bbox_inches="tight")
    finally:
        plt.close(fig)


def ski_rental_experiment(
    b, r, w_fn, n_trials=100, mu_fn=mu_gauss, z=4.0, seed=42, delta=0.9,
    alphas=(0.1, 0.5, 0.9), save_csv=None, plot_path=None
):
    """Monte Carlo comparison of learned ski-rental thresholds with baselines.

    Each trial draws ``y ~ Uniform(b / z, b * z)`` and sets
    ``h1 = h2 = delta * y``. It builds ``mu = mu_fn(x, y, h1, h2)`` on a fixed
    grid ``x`` over ``[eps, b * r]`` and chooses a threshold with
    ``find_lambda`` for each algorithm: ``"MAX"``, ``"AVG"``, and one
    ``"CVaR_<alpha>"`` per entry of ``alphas``. Four fixed-threshold
    baselines (``A_b``, ``A_b(r-1)``, ``A_b/(r-1)``, ``A_mid``) are evaluated
    alongside. Recorded per trial:

    * the mean ratio ``c_alg / c_opt`` over ``[y (1 - delta), y (1 + delta)]``
      (500 points);
    * for each learned algorithm, the fraction of those points where its
      ratio is strictly below each baseline's;
    * the expected cost under ``mu`` (trapezoidal integral over ``x``).

    Args:
        b: Purchase price.
        r: Parameter setting the lambda search range and the grid end ``b * r``.
        w_fn: Weight function for the MAX/AVG objectives.
        n_trials: Number of Monte Carlo trials.
        mu_fn: Density factory called as ``mu_fn(x, y, h1, h2)``.
        z: Multiplicative spread of ``y`` around ``b``.
        seed: Seed for the trial draws and the bootstrap.
        delta: Relative half-width ``h1 / y = h2 / y``.
        alphas: CVaR levels; one CVaR algorithm per entry.
        save_csv: Optional path for per-(trial, algorithm) CSV rows.
        plot_path: Optional path (stem) for PDF/PNG bar charts.

    Returns:
        ``(S, CI, R)``.
        ``S`` maps summary names to bootstrap means.
        ``CI`` maps the same names to ``(lower, upper)`` distances of the
        95% percentile-bootstrap interval from the mean.
        ``R`` holds the raw per-trial lists.
    """
    rng = np.random.default_rng(seed)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    alg_specs = [("MAX", "max", None), ("AVG", "avg", None)]
    alg_specs += [(f"CVaR_{a}", "cvar", a) for a in alphas]
    alg_names = [name for name, _, _ in alg_specs]
    bases = ["A_b", "A_b(r-1)", "A_b/(r-1)", "A_mid"]
    # Fixed buying thresholds of the baseline algorithms.
    # NOTE(review): A_mid is b*r/2 + b as written; confirm this is the
    # intended "mid" point (e.g. not (b*r + b)/2).
    base_lams = {
        "A_b": b,
        "A_b(r-1)": b * (r - 1.0),
        "A_b/(r-1)": b / (r - 1.0),
        "A_mid": 0.5 * b * r + b,
    }

    R = {"baselines": {k: [] for k in bases}, "baselines_E": {k: [] for k in bases}, "E_opt": []}
    for name in alg_names:
        R[name] = {"avg_ALG": [], **{f"imp_vs_{k}": [] for k in bases}, "E_cost": []}

    rows = []
    eps = 1e-9
    # The integration / search grid does not depend on the trial; build it once.
    x = np.linspace(eps, b * r, 1000)
    for t in range(n_trials):
        y = rng.uniform(b / z, b * z)
        h1 = h2 = delta * y
        mu = mu_fn(x, y, h1, h2)

        lo, hi = max(eps, y * (1.0 - delta)), y * (1.0 + delta)
        x_eval = np.linspace(lo, hi, 500)
        denom = c_opt(x_eval, b)

        base_ratio = {k: c_alg(x_eval, lam, b) / denom for k, lam in base_lams.items()}
        for k in bases:
            R["baselines"][k].append(float(np.mean(base_ratio[k])))
        R["E_opt"].append(float(integrate(c_opt(x, b) * mu, x)))
        for k in bases:
            R["baselines_E"][k].append(float(integrate(c_alg(x, base_lams[k], b) * mu, x)))

        for name, obj, a in alg_specs:
            lam = find_lambda(x, b, w_fn, y, h1, h2, a if obj == "cvar" else 0.5, obj, mu, r, False)
            ratio_alg = c_alg(x_eval, lam, b) / denom
            rec = R[name]
            rec["avg_ALG"].append(float(np.mean(ratio_alg)))
            for k in bases:
                rec[f"imp_vs_{k}"].append(float(np.mean(ratio_alg < base_ratio[k])))
            E_alg = float(integrate(c_alg(x, lam, b) * mu, x))
            rec["E_cost"].append(E_alg)

            if save_csv is not None:
                row = {"trial": t, "alg": name, "avg_ALG": rec["avg_ALG"][-1]}
                row.update({f"imp_vs_{k}": rec[f"imp_vs_{k}"][-1] for k in bases})
                row.update({f"avg_{k}": R["baselines"][k][-1] for k in bases})
                row.update({"E_cost": E_alg, "E_opt": R["E_opt"][-1]})
                row.update({f"E_{k}": R["baselines_E"][k][-1] for k in bases})
                row.update({"y": y, "h1": h1, "h2": h2})
                rows.append(row)

    def _mean_ci_asym(samples, conf=0.95, n_boot=2000):
        # Mean plus percentile-bootstrap CI, as (mean - lo, hi - mean).
        arr = np.asarray(samples, dtype=float)
        n = arr.size
        if n == 0:
            return float("nan"), (0.0, 0.0)
        mean = float(np.mean(arr))
        if n == 1:
            return mean, (0.0, 0.0)
        idx = rng.integers(0, n, size=(n_boot, n))
        boots = np.mean(arr[idx], axis=1)
        a = (1.0 - conf) / 2.0
        lo, hi = np.quantile(boots, [a, 1.0 - a])
        return mean, (mean - float(lo), float(hi) - mean)

    S, CI = {}, {}

    def _record(key, samples):
        # The order of these calls fixes the rng stream used by the bootstrap.
        S[key], CI[key] = _mean_ci_asym(samples)

    for base in bases:
        _record(f"avg_ratio_{base}", R["baselines"][base])
        _record(f"E_cost_{base}", R["baselines_E"][base])
    _record("E_opt", R["E_opt"])
    for name in alg_names:
        _record(f"avg_ratio_{name}", R[name]["avg_ALG"])
        for base in bases:
            _record(f"time_{name}_lt_{base}", R[name][f"imp_vs_{base}"])
        _record(f"E_cost_{name}", R[name]["E_cost"])

    if save_csv is not None:
        cols = (
            ["trial", "alg", "avg_ALG"]
            + [f"imp_vs_{k}" for k in bases]
            + [f"avg_{k}" for k in bases]
            + ["E_cost", "E_opt"]
            + [f"E_{k}" for k in bases]
            + ["y", "h1", "h2"]
        )
        _write_rows_csv(save_csv, cols, rows)

    if plot_path is not None:
        _plot_avg_ratios(S, CI, bases + alg_names, plot_path)

    return S, CI, R

# example run
if __name__ == "__main__":
    # Demo run: b = 10, r = 5, Gaussian weights and density, 100 trials.
    S, CI, _ = ski_rental_experiment(
        b=10,
        r=5.0,
        w_fn=w_gauss,
        n_trials=100,
        mu_fn=mu_gauss,
        z=7,
        seed=42,
        delta=0.9,
        alphas=(0.1, 0.5, 0.9),
        plot_path=None,
    )
    # Print each mean average ratio with its +upper/-lower bootstrap CI widths.
    baseline_keys = ["A_b", "A_b(r-1)", "A_b/(r-1)", "A_mid"]
    learned_keys = ["MAX", "AVG"] + [f"CVaR_{a}" for a in (0.1, 0.5, 0.9)]
    for k in baseline_keys + learned_keys:
        stat_key = f"avg_ratio_{k}"
        m = S.get(stat_key, np.nan)
        lo, hi = CI.get(stat_key, (0, 0))
        print(f"{k:10s}: {m:.4f}  +{hi:.4f}/-{lo:.4f}")
